from py65emu.cpu import CPU
from py65emu.mmu import MMU
import time
import string
import os
import sys
from keyboard import is_pressed
import pygame
from pynput import keyboard
pygame.init()
# Font used by display mode 1 (text); the relative path resolves against the
# current working directory.
font = pygame.font.Font("displayfont.ttf", 30)
# 640x640 window: the main loop draws a 16x16 grid of 40-pixel cells into it.
w, h = 640, 640
screen = pygame.display.set_mode((w, h))
arg = sys.argv

# Latest key reported by the pynput listener callbacks, as raw bytes.
# Starts empty; the callbacks below rebind it from the listener thread.
key = b""
def on_press(keyy):
    """pynput press callback: remember the key just pressed in the global ``key``.

    Character keys are stored as their encoded character (case-sensitive, as
    reported by pynput).  Backspace, Delete, Enter and Space are translated to
    the control bytes 0x10, 0x11, 0x1F and 0x20.  Any other key leaves ``key``
    unchanged.

    The main loop copies ``key`` into the single memory byte at 0x301, so a
    character whose encoding is longer than one byte (any non-ASCII text) is
    ignored instead of being stored as a value that cannot fit in that cell.
    """
    global key
    # Special keys (pynput ``Key`` members) have no ``char``; some KeyCodes
    # report ``char`` as None.  Both fall through to the special-key table.
    char = getattr(keyy, "char", None)
    if char is not None:
        encoded = char.encode()
        if len(encoded) == 1:
            key = encoded
        return
    for special, code in (
        (keyboard.Key.backspace, b"\x10"),
        (keyboard.Key.delete, b"\x11"),
        (keyboard.Key.enter, b"\x1F"),
        (keyboard.Key.space, b"\x20"),
    ):
        if keyy == special:
            key = code
            return

def on_release(keyy):
    """pynput release callback: reset the global ``key`` to a single NUL byte.

    Releasing a character key resets ``key`` only when that character is the
    one currently stored.  Releasing a key that has no character (e.g. Shift
    or Backspace) always resets it.
    """
    global key
    released_char = getattr(keyy, "char", None)
    if released_char is None or key == released_char.encode():
        key = b"\x00"

# Command line: argv[1] = program image mapped at 0x0400 (optional),
#               argv[2] = boot ROM image mapped at 0x0380 (default "rom.bin").
rom_path = arg[2] if len(arg) > 2 else "rom.bin"
program_path = arg[1] if len(arg) > 1 else None

fr = open(rom_path, "rb")
# With no program given, pass value=None ("no initial contents") so the block
# starts zero-filled.  An empty bytes() is not a file object.
# NOTE(review): py65emu appears to call value.read(), which bytes lacks.
f = open(program_path, "rb") if program_path is not None else None

# Memory map.  Each tuple is (start_address, length, readOnly, value, valueOffset);
# trailing fields are optional and value=None leaves the block zero-filled.
#   0x0000-0x037F  writable RAM: display buffer at 0x0100-0x01FF, stack in
#                  page 2, I/O bytes at 0x0300-0x0303 (see the main loop)
#   0x0380-0x03FF  boot ROM from rom_path (execution starts here)
#   0x0400-0x05FF  program image from program_path
#   0x0600-0xFFFF  read-only, zero-filled remainder of the 64 KiB space
try:
    m = MMU([
        (0x0000, 0x0380),                  # 896 bytes of RAM (readOnly omitted)
        (0x0380, 0x0080, True, fr),        # 128-byte boot ROM
        (0x0400, 0x0200, True, f),         # 512-byte program area
        (0x0600, 0x10000 - 0x0600, True),  # stop exactly at 0xFFFF
    ])
finally:
    # NOTE(review): assumes py65emu copies the file contents into its own
    # buffers while building the MMU (it reads from memory afterwards).
    fr.close()
    if f is not None:
        f.close()

# Start executing at the boot ROM (0x0380).  The third argument is
# stack_page.  It defaults to 1, meaning a stack at 0x0100-0x01FF, but page 1
# is read by the main loop as the 16x16 display buffer, so the stack is moved
# to page 2 (0x0200-0x02FF).
c = CPU(m, 0x0380, 2)
# py65emu usage cheat-sheet.  This bare string literal is never executed.
"""
# Do this to execute one instruction
c.step()

# You can check the registers and memory values to determine what has changed
print(c.r.a) 	# A register
print(c.r.x) 	# X register
print(c.r.y) 	# Y register
print(c.r.s) 	# Stack Pointer
print(c.r.pc) 	# Program Counter

print(c.cc)     # Print the number of cycles that passed during the last step.
                # This number resets for each call to `.step()`

print(c.r.getFlag('C')) # Get the value of a flag from the flag register.

print(m.read(0x100)) # Read a value from memory
"""
# Pixel size of one display cell: the 640x640 window shows a 16x16 grid.
_CELL = 40
_HALF = _CELL // 2
_WHITE = (255, 255, 255)
_BLACK = (0, 0, 0)

# Input bindings polled every step and written to 0x300.  They are checked in
# this order and the last one held wins (e.g. Space overrides a direction).
_INPUT_BINDINGS = (
    ("up arrow", 4),
    ("right arrow", 5),
    ("down arrow", 6),
    ("left arrow", 7),
    ("w", 4),
    ("d", 5),
    ("s", 6),
    ("a", 7),
    ("space", 15),
)

# Four-colour palettes for display mode 3, selected by the byte at 0x303.
_MODE3_PALETTES = {
    0: ((0, 0, 0), (255, 85, 85), (255, 255, 85), (85, 255, 255)),
    1: ((0, 0, 0), (85, 85, 85), (170, 170, 170), (255, 255, 255)),
}


def _write_inputs():
    """Store the held direction/action code at 0x300 and the typed key at 0x301."""
    code = 0
    for name, value in _INPUT_BINDINGS:
        if is_pressed(name):
            code = value
    m.write(0x300, code)
    # Snapshot once: the pynput listener thread may rebind ``key`` at any time.
    pressed = key
    # 0x301 is a single byte; anything that is not exactly one byte reads as 0
    # (a multi-byte value would not fit in the cell).
    m.write(0x301, pressed[0] if len(pressed) == 1 else 0)


def _draw_rgb332(memval, px, py):
    """Mode 0: fill the cell with an 8-bit colour laid out as RRRGGGBB."""
    r = round(((memval >> 5) & 7) * (255 / 7))
    g = round(((memval >> 2) & 7) * (255 / 7))
    b = round((memval & 3) * (255 / 3))
    pygame.draw.rect(screen, (r, g, b), (px, py, _CELL, _CELL))


def _draw_text(memval, px, py):
    """Mode 1: draw the cell as a white character on black.

    0x00 is blank, 0x01-0x7F render as their ASCII character, 0xFF is a solid
    white block, and 0x80-0xFE render as "?".
    """
    pygame.draw.rect(screen, _BLACK, (px, py, _CELL, _CELL))
    if memval == 255:
        pygame.draw.rect(screen, _WHITE, (px, py, _CELL, _CELL))
        return
    if memval == 0:
        return
    text = chr(memval) if memval < 128 else "?"
    try:
        surface = font.render(text, True, _WHITE)
    except (pygame.error, ValueError):
        # Characters the font refuses to render fall back to "?".
        surface = font.render("?", True, _WHITE)
    screen.blit(surface, (px, py))


def _nibble_colour(nibble):
    """Map a 4-bit IRGB value to RGB: 170 per set colour bit, +85 if I is set."""
    intensity = 85 if nibble & 8 else 0
    return (
        (170 if nibble & 4 else 0) + intensity,
        (170 if nibble & 2 else 0) + intensity,
        (170 if nibble & 1 else 0) + intensity,
    )


def _draw_nibbles(memval, px, py):
    """Mode 2: two half-width pixels, high nibble on the left and low nibble on the right."""
    pygame.draw.rect(screen, _nibble_colour(memval >> 4), (px, py, _HALF, _CELL))
    pygame.draw.rect(screen, _nibble_colour(memval & 15), (px + _HALF, py, _HALF, _CELL))


def _draw_quads(memval, px, py, palette):
    """Mode 3: four quarter-cell pixels of 2 bits each, looked up in *palette*.

    The top-left, top-right, bottom-left and bottom-right quadrants use bits
    0-1, 2-3, 4-5 and 6-7 respectively.
    """
    for quad in range(4):
        row, col = divmod(quad, 2)
        colour = palette[(memval >> (2 * quad)) & 3]
        pygame.draw.rect(screen, colour, (px + col * _HALF, py + row * _HALF, _HALF, _HALF))


def _draw_display():
    """Render the 16x16 buffer at 0x0100-0x01FF in the mode selected by 0x302.

    Unknown modes (or mode 3 with an unknown palette in 0x303) draw nothing,
    so the previous frame stays on screen.
    """
    mode = m.read(0x302)
    if mode == 0:
        draw = _draw_rgb332
    elif mode == 1:
        draw = _draw_text
    elif mode == 2:
        draw = _draw_nibbles
    elif mode == 3:
        palette = _MODE3_PALETTES.get(m.read(0x303))
        if palette is None:
            return

        def draw(memval, px, py):
            _draw_quads(memval, px, py, palette)
    else:
        return
    for row in range(16):
        for col in range(16):
            draw(m.read(0x100 + row * 16 + col), col * _CELL, row * _CELL)


def _dump_state():
    """Print pages 0-1 and 32 bytes from 0x300 as hex, then the CPU registers."""
    for page in range(2):
        for line in range(8):
            base = page * 256 + line * 32
            print(*(f"{m.read(base + i):02x}" for i in range(32)), end=" \n")
        print()
    print(*(f"{m.read(0x300 + i):02x}" for i in range(32)), end=" \n")
    # Trailing padding overwrites leftovers from a longer previous value.
    print("a:", c.r.a, "            ")
    print("x:", c.r.x, "            ")
    print("y:", c.r.y, "            ")
    print("^:", c.r.pc, "            ")
    print("v:", c.r.s, "            ")


screen.fill(0)
# Clear the terminal once; each step then homes the cursor and redraws in place.
os.system("cls" if os.name == "nt" else "clear")
# The pynput listener calls on_press/on_release in the background while this
# block is active.
with keyboard.Listener(on_press=on_press, on_release=on_release) as listener:
    tme_cnt = 0
    new_time = 0  # last measured speed, in CPU steps per second
    start_time = time.time()
    while listener.running:
        _write_inputs()
        c.step()
        _draw_display()
        pygame.display.flip()
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                sys.exit()
        _dump_state()
        tme_cnt += 1
        if tme_cnt == 100:
            tme_cnt = 0
            now = time.time()
            elapsed = now - start_time
            new_time = round(100 / elapsed) if elapsed > 0 else 0
            start_time = now
        print("cps:", new_time, "        ")
        # Move the cursor to the top-left so the next dump overwrites this one.
        print("\033[{row};{col}H".format(row=0, col=0), end="")